from pdbench.benchmark_methods.bioemu_sampler import BioemuSampler
from pdbench.benchmark_methods.chroma_sampler import ChromaSampler
from pdbench.benchmark_methods.esmdiff_sampler import ESMDiffSampler
from pdbench.benchmark_methods.alphaflow_sampler import AlphaFlowSampler
from pdbench.benchmark_methods.protenia_sampler import ProteniaSampler


# Names accepted by ``get_sampler``. Keep in sync with the dispatch table
# inside ``get_sampler``; exposed so callers (e.g. CLI ``choices=``) can reuse it.
SAMPLER_TYPES = ("chroma", "esmdiff", "alphaflow", "protenia", "bioemu")


class InvalidSamplerTypeError(ValueError, AssertionError):
    """Raised by ``get_sampler`` for an unknown sampler name.

    Subclasses ``AssertionError`` as well as ``ValueError`` so code written
    against the previous ``assert``-based check keeps catching it.
    """


def get_sampler(
    sampler_type: str
) -> type:
    """Return the sampler class registered under ``sampler_type``.

    The class itself is returned, not an instance; the caller constructs it.

    Args:
        sampler_type: One of ``SAMPLER_TYPES`` (matched exactly, case-sensitive).

    Returns:
        The corresponding sampler class.

    Raises:
        InvalidSamplerTypeError: If ``sampler_type`` is not a known name.
            Raised explicitly (not via ``assert``) so the check still runs
            under ``python -O``; previously an unknown name returned ``None``.
    """
    if sampler_type not in SAMPLER_TYPES:
        raise InvalidSamplerTypeError(
            f"{sampler_type} is not a valid sampler type "
            f"(expected one of: {', '.join(SAMPLER_TYPES)})"
        )

    # Built at call time so the sampler classes are only looked up once a
    # valid name has been confirmed.
    samplers = {
        "chroma": ChromaSampler,
        "esmdiff": ESMDiffSampler,
        "alphaflow": AlphaFlowSampler,
        "protenia": ProteniaSampler,
        "bioemu": BioemuSampler,
    }
    return samplers[sampler_type]